# Single-unit linear regression trained with a custom absolute percentage
# error objective, |prediction / label - 1|, via mx.symbol.MakeLoss.
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")

fc1 <- mx.symbol.FullyConnected(data = data, num_hidden = 1, name = "fc1")
# Collapse the (batch, 1) output of the single hidden unit to a vector so it
# lines up element-wise with `label`.
fc1 <- mx.symbol.Reshape(data = fc1, shape = c(0), name = "fc1_reshape")

# Per-sample absolute percentage error. The division by `label` means any
# target equal to 0 yields Inf/NaN, so targets must be strictly non-zero.
perc_err <- mx.symbol.abs(fc1 / label - 1, name = "perc_error")

# MakeLoss backpropagates perc_err directly. Its forward output is the loss
# value itself, NOT the prediction: eval metrics and predict() on this model
# see per-sample percentage errors rather than fitted values.
custom_loss <- mx.symbol.MakeLoss(perc_err, name = "loss")
graph.viz(custom_loss, direction = "LR", graph.height.px = 160)

# Because the model output is already the per-sample error, the meaningful
# metric is simply its mean (MAPE). Do not use mx.metric.rmse here: it would
# compare the error values against the raw labels, which just approximates
# the RMS of the labels and *increases* as the loss decreases.
mape_metric <- mx.metric.custom("mape", function(labels, loss_out) {
  mean(as.array(loss_out))
})

model_reg <- mx.model.FeedForward.create(
  symbol = custom_loss,
  X = train_x, y = train_y,
  eval.data = list(data = test_x, label = test_y),
  ctx = mx.cpu(),
  num.round = 24,
  array.batch.size = 32,
  # Layout auto-detection picks row-major (one sample per row) for this
  # input; stating it explicitly avoids the auto-detection warning.
  array.layout = "rowmajor",
  optimizer = "sgd",
  learning.rate = 0.000001,
  momentum = 0.9,
  wd = 0.0001,
  epoch.end.callback = mx.callback.log.train.metric(1),
  eval.metric = mape_metric
)
# To obtain fitted values, predict with the internal "fc1_reshape_output"
# symbol (e.g. taken from internals(model_reg$symbol)) together with the
# trained parameters, rather than with model_reg directly.
## Warning in mx.model.select.layout.train(X, y): Auto detect layout of input matrix, use rowmajor..
## Start training with 1 devices
## [1] Train-rmse=24.1800087672059
## [1] Validation-rmse=22.6335869929121
## [2] Train-rmse=24.0300878467801
## [2] Validation-rmse=22.6708376493816
## [3] Train-rmse=24.0707992455687
## [3] Validation-rmse=22.7176125994856
## [4] Train-rmse=24.1181655824697
## [4] Validation-rmse=22.7686102278079
## [5] Train-rmse=24.1685198924117
## [5] Validation-rmse=22.8215426928708
## [6] Train-rmse=24.2202787888647
## [6] Validation-rmse=22.8754213361408
## [7] Train-rmse=24.2720386027599
## [7] Validation-rmse=22.9282619103808
## [8] Train-rmse=24.3217555028644
## [8] Validation-rmse=22.9764323660188
## [9] Train-rmse=24.3670866650711
## [9] Validation-rmse=23.0190863276446
## [10] Train-rmse=24.4068048566368
## [10] Validation-rmse=23.0559792096579
## [11] Train-rmse=24.439853562234
## [11] Validation-rmse=23.0872859699407
## [12] Train-rmse=24.4668716303623
## [12] Validation-rmse=23.1133190134204
## [13] Train-rmse=24.4886991372452
## [13] Validation-rmse=23.1341658301637
## [14] Train-rmse=24.5058598419352
## [14] Validation-rmse=23.1511550137497
## [15] Train-rmse=24.5206837982417
## [15] Validation-rmse=23.1661712503112
## [16] Train-rmse=24.5339009113935
## [16] Validation-rmse=23.1800148831064
## [17] Train-rmse=24.5461397612431
## [17] Validation-rmse=23.1927502141254
## [18] Train-rmse=24.5574676320726
## [18] Validation-rmse=23.2044778470468
## [19] Train-rmse=24.5676633243144
## [19] Validation-rmse=23.214840210744
## [20] Train-rmse=24.5766600105923
## [20] Validation-rmse=23.2241971884832
## [21] Train-rmse=24.5848173914259
## [21] Validation-rmse=23.2326224167931
## [22] Train-rmse=24.5920757539627
## [22] Validation-rmse=23.2402938266856
## [23] Train-rmse=24.5984723358449
## [23] Validation-rmse=23.2471215105821
## [24] Train-rmse=24.6042160590102
## [24] Validation-rmse=23.2533549538859